Enable gfx950 CI on dev branch #401
Conversation
#ifdef __HIP_PLATFORM_AMD__
// Temporary skip: gfx950 TN kernels for (M,K,N)=(2304,768,4096) are unstable.
What does unstable mean?
6192 - OperatorTest/GEMMTestSuite.Testfp8xfp8xbf16xbf16xbf16/2304x768x4096x0x0xTNxM # GetParam() = ((2304, 768, 4096), false, false, (true, false), 1) (Failed)
6768 - OperatorTest/GEMMTestSuite.Testfp8xbf8xbf16xbf16xfp16/2304x768x4096x0x0xTNxM # GetParam() = ((2304, 768, 4096), false, false, (true, false), 1) (Failed)
7344 - OperatorTest/GEMMTestSuite.Testbf8xfp8xbf16xbf16xfp32/2304x768x4096x0x0xTNxM # GetParam() = ((2304, 768, 4096), false, false, (true, false), 1) (Failed)
7488 - OperatorTest/GEMMTestSuite.Testbf8xfp8xbf16xbf16xfp16/2304x768x4096x0x0xTNxM # GetParam() = ((2304, 768, 4096), false, false, (true, false), 1) (Failed)
These test cases fail intermittently, so we decided to skip them for this MI350 bring-up. When I tested on ROCm 7.2 there was no issue.
Guard it with #if HIP_VERSION < 70200000 then. That way the comments about the temporary disable and re-enable, and the mention of ROCm 7.2, can be removed.
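A minimal sketch of that guard, assuming the skip sits inside a GoogleTest case and that prop and params come from the surrounding fixture; GTEST_SKIP is illustrative, not necessarily how the PR spells the skip. HIP_VERSION (from hip/hip_version.h) encodes major*10000000 + minor*100000 + patch, so 70200000 is ROCm 7.2.0:

#include <hip/hip_version.h>  // defines HIP_VERSION

// Compile the skip only where the hipBLASLt instability exists (ROCm < 7.2).
// On newer toolchains the whole block drops out, so no re-enable note is needed.
#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 70200000
  if (prop.major == 9 && prop.minor == 5 &&                         // gfx950
      params.transa && !params.transb &&                            // TN layout
      params.m == 2304 && params.k == 768 && params.n == 4096) {
    GTEST_SKIP() << "gfx950 TN (2304,768,4096) unstable with hipBLASLt before ROCm 7.2";
  }
#endif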
// Re-enable after ROCm 7.2 once hipBLASLt fixes land.
if (prop.major == 9 && prop.minor == 5 &&
    params.transa && !params.transb &&
    params.m == 2304 && params.k == 768 && params.n == 4096) {
There is only one size for DqTest. Instead of skipping the test, just use a different size for test_case_sizes_mxfp8, for example 768, 3072, 4096.
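A hypothetical sketch of that change; the element type of test_case_sizes_mxfp8 is assumed here, only the name and the suggested size come from this thread:

#include <tuple>
#include <vector>

// Replace the flaky (M,K,N) instead of skipping the only DqTest size.
static const std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {
    // {2304, 768, 4096},  // TN case fails intermittently on gfx950 before ROCm 7.2
    {768, 3072, 4096},     // reviewer-suggested replacement size
};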
updated
rebase to dev (force-pushed from 5a83295 to b551b3f)
Test report for MI355 with Level=3:
Description
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
- Disable cudaGraph registration for the JAX gemm and grouped_gemm FFI handlers on ROCm to stop graph-capture hangs on gfx950 (transformer_engine/jax/csrc/extensions/gemm.cpp).
- Keep is_fp8_gemm_with_all_layouts_supported false on gfx950 until hipBLASLt FP8 layout coverage is validated (transformer_engine/jax/quantize/device_utils.py).
- Fix the RMSNorm Triton kernel for misaligned row strides by applying 16B alignment hints only when the pointers and strides are actually aligned; this resolves the test_norms dgamma mismatches and the test_transformer_layer_hidden_states_format numerics issues. Also relax fused-optimizer FP8 tolerances on MI350 (transformer_engine/pytorch/triton_kernels/rmsnorm.py, tests/pytorch/test_numerics.py, tests/pytorch/test_fused_optimizer.py).
- Skip unsupported FP8 quantized-linear combinations on gfx950 where hipBLASLt lacks algorithms (tests/pytorch/test_fusible_ops.py).
- Add a gfx950 detection helper and skip test_gpt_full_activation_recompute on MI350 configs that hipBLASLt cannot serve; a HIP-side sketch of such a check follows this list (transformer_engine/pytorch/utils.py, tests/pytorch/test_numerics.py).
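The PR's detection helper is Python (transformer_engine/pytorch/utils.py); as a hedged HIP C++ analogue, matching the language of the test snippet above, a gfx950 check can key off gcnArchName rather than the major/minor pair used in the skip:

#include <hip/hip_runtime.h>
#include <cstring>

// Sketch only: true when the given device reports a gfx950 (MI350-class) GPU.
// gcnArchName looks like "gfx950:sramecc+:xnack-", so match the prefix.
bool is_gfx950(int device = 0) {
  hipDeviceProp_t prop;
  if (hipGetDeviceProperties(&prop, device) != hipSuccess) return false;
  return std::strncmp(prop.gcnArchName, "gfx950", 6) == 0;
}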
Checklist: